In [1]:
# Define experiment parameters
year = "200506"
target_col = "blue_collar"  # 'white_collar', 'blue_collar', 'has_occ'
sample_weight_col = 'women_weight'
In [2]:
# Define resource utilization parameters
random_state = 42
n_jobs_clf = 16
n_jobs_cv = 4
cv_folds = 5
In [3]:
import numpy as np
np.random.seed(random_state)

import pandas as pd
pd.set_option('display.max_columns', 500)

import matplotlib.pylab as pl

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

from sklearn.utils.class_weight import compute_class_weight

import lightgbm
from lightgbm import LGBMClassifier

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.model_selection import StratifiedKFold

import shap

import pickle
from joblib import dump, load

Prepare Dataset

In [4]:
# Load dataset
dataset = pd.read_csv(f"data/women_work_data_{year}.csv")
print("Loaded dataset: ", dataset.shape)
dataset.head()
Loaded dataset:  (116417, 25)
Out[4]:
case_id_str country_code cluster_no hh_no line_no state hh_religion caste occupation no_occ white_collar blue_collar has_occ women_weight urban women_anemic total_children obese_female age years_edu wealth_index hh_members no_children_below5 freq_tv year
0 28 1 1 2 IA5 28001 1 2 [ap] andhra pradesh hindu scheduled tribe 0 1 0 0 0 2.474583 0 NaN 4 NaN 47 0 2 6 2 0.0 2005
1 28 1 1 4 IA5 28001 1 4 [ap] andhra pradesh hindu scheduled tribe 5 0 0 1 1 2.474583 0 0.0 2 0.0 23 2 2 6 2 3.0 2005
2 28 1 2 2 IA5 28001 2 2 [ap] andhra pradesh hindu other backward class 5 0 0 1 1 2.474583 0 0.0 1 0.0 20 0 2 5 0 3.0 2005
3 28 1 2 5 IA5 28001 2 5 [ap] andhra pradesh hindu other backward class 5 0 0 1 1 2.474583 0 1.0 0 0.0 40 0 2 5 0 2.0 2005
4 28 1 4 3 IA5 28001 4 3 [ap] andhra pradesh hindu other backward class 8 0 0 1 1 2.474583 0 0.0 2 0.0 31 7 3 5 0 3.0 2005
In [5]:
# See distribution of target values
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
Target column distribution:
 0    78217
1    38200
Name: blue_collar, dtype: int64
In [6]:
# Drop samples where the target is missing
dataset.dropna(axis=0, subset=[target_col, sample_weight_col], inplace=True)
print("Drop missing targets: ", dataset.shape)
Drop missing targets:  (116417, 25)
In [7]:
# Drop samples where age < 21
dataset = dataset[dataset['age'] >= 21]
print("Drop under-21 samples: ", dataset.shape)
Drop under-21 samples:  (89163, 25)
In [8]:
# See new distribution of target values
print("Target column distribution:\n", dataset[target_col].value_counts(dropna=False))
Target column distribution:
 0    58293
1    30870
Name: blue_collar, dtype: int64
In [9]:
# Post-processing

# Group SC/ST castes together
dataset['caste'][dataset['caste'] == 'scheduled caste'] = 'sc/st'
dataset['caste'][dataset['caste'] == 'scheduled tribe'] = 'sc/st'
if year == "200506":
    dataset['caste'][dataset['caste'] == '9'] = "don\'t know"

# Fix naming for General caste
dataset['caste'][dataset['caste'] == 'none of above'] = 'general'

if year == "201516":
    # Convert wealth index from str to int values
    wi_dict = {'poorest': 0, 'poorer': 1, 'middle': 2, 'richer': 3, 'richest': 4}
    dataset['wealth_index'] = [wi_dict[wi] for wi in dataset['wealth_index']]
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:5: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  """
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  import sys
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:10: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  # Remove the CWD from sys.path while we load stuff.
In [10]:
# Define feature columns
x_cols_categorical = ['state', 'hh_religion', 'caste']
x_cols_binary = ['urban', 'women_anemic', 'obese_female']
x_cols_numeric = ['age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
x_cols = x_cols_categorical + x_cols_binary + x_cols_numeric
print("Feature columns:\n", x_cols)
Feature columns:
 ['state', 'hh_religion', 'caste', 'urban', 'women_anemic', 'obese_female', 'age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
In [11]:
# Drop samples with missing values in feature columns
dataset.dropna(axis=0, subset=x_cols, inplace=True)
print("Drop missing feature value rows: ", dataset.shape)
Drop missing feature value rows:  (77762, 25)
In [12]:
# Separate target column
targets = dataset[target_col]
# Separate sampling weight column
sample_weights = dataset[sample_weight_col]
# Drop columns which are not part of features
dataset.drop(columns=[col for col in dataset.columns if col not in x_cols], axis=1, inplace=True)
print("Drop extra columns: ", dataset.shape)
Drop extra columns:  (77762, 13)
In [13]:
# Obtain one-hot encodings for the caste column
dataset = pd.get_dummies(dataset, columns=['caste'])
x_cols_categorical.remove('caste')  # Remove 'caste' from categorical variables list
print("Caste to one-hot: ", dataset.shape)
Caste to one-hot:  (77762, 16)
In [14]:
dataset_display = dataset.copy()
print("Create copy for visualization: ", dataset_display.shape)
dataset_display.head()
Create copy for visualization:  (77762, 16)
Out[14]:
state hh_religion urban women_anemic total_children obese_female age years_edu wealth_index hh_members no_children_below5 freq_tv caste_don't know caste_general caste_other backward class caste_sc/st
1 [ap] andhra pradesh hindu 0 0.0 2 0.0 23 2 2 6 2 3.0 0 0 0 1
3 [ap] andhra pradesh hindu 0 1.0 0 0.0 40 0 2 5 0 2.0 0 0 1 0
4 [ap] andhra pradesh hindu 0 0.0 2 0.0 31 7 3 5 0 3.0 0 0 1 0
5 [ap] andhra pradesh muslim 0 0.0 2 0.0 22 0 3 7 2 3.0 0 0 1 0
6 [ap] andhra pradesh hindu 0 0.0 2 1.0 25 6 3 3 0 3.0 0 1 0 0
In [15]:
# Obtain integer encodings for other categorical features
for col in x_cols_categorical:
    dataset[col] = pd.factorize(dataset[col])[0]
print("Categoricals to int encodings: ", dataset.shape)
Categoricals to int encodings:  (77762, 16)
In [16]:
dataset.head()
Out[16]:
state hh_religion urban women_anemic total_children obese_female age years_edu wealth_index hh_members no_children_below5 freq_tv caste_don't know caste_general caste_other backward class caste_sc/st
1 0 0 0 0.0 2 0.0 23 2 2 6 2 3.0 0 0 0 1
3 0 0 0 1.0 0 0.0 40 0 2 5 0 2.0 0 0 1 0
4 0 0 0 0.0 2 0.0 31 7 3 5 0 3.0 0 0 1 0
5 0 1 0 0.0 2 0.0 22 0 3 7 2 3.0 0 0 1 0
6 0 0 0 0.0 2 1.0 25 6 3 3 0 3.0 0 1 0 0
In [17]:
# Create Training, Validation and Test sets
X_train, X_test, Y_train, Y_test, W_train, W_test = train_test_split(dataset, targets, sample_weights, test_size=0.05, random_state=random_state, stratify=targets)
# X_train, X_val, Y_train, Y_val, W_train, W_val = train_test_split(X_train, Y_train, W_train, test_size=0.1)
print("Training set: ", X_train.shape, Y_train.shape, W_train.shape)
# print("Validation set: ", X_val.shape, Y_val.shape, W_val.shape)
print("Test set: ", X_test.shape, Y_test.shape, W_test.shape)
train_cw = compute_class_weight("balanced", classes=np.unique(Y_train), y=Y_train)
print("Class weights: ", train_cw)
Training set:  (73873, 16) (73873,) (73873,)
Test set:  (3889, 16) (3889,) (3889,)
Class weights:  [0.77708702 1.40224365]

Build LightGBM Classifier

In [18]:
# # Define LightGBM Classifier
# model = LGBMClassifier(boosting_type='gbdt', 
#                        feature_fraction=0.8,  
#                        learning_rate=0.01,
#                        max_bins=64,
#                        max_depth=-1,
#                        min_child_weight=0.001,
#                        min_data_in_leaf=50,
#                        min_split_gain=0.0,
#                        num_iterations=1000,
#                        num_leaves=64,
#                        reg_alpha=0,
#                        reg_lambda=1,
#                        subsample_for_bin=200000,
#                        is_unbalance=True,
#                        random_state=random_state, 
#                        n_jobs=n_jobs_clf, 
#                        silent=True, 
#                        importance_type='split')
In [19]:
# # Fit model on training set
# model.fit(X_train, Y_train, sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])
In [20]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [21]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-model.joblib')
# del model
In [22]:
# # Define hyperparameter grid
# param_grid = {
#     'num_leaves': [8, 32, 64],
#     'min_data_in_leaf': [10, 20, 50],
#     'max_depth': [-1], 
#     'learning_rate': [0.01, 0.1], 
#     'num_iterations': [1000, 3000, 5000], 
#     'subsample_for_bin': [200000],
#     'min_split_gain': [0.0], 
#     'min_child_weight': [0.001],
#     'feature_fraction': [0.8, 1.0], 
#     'reg_alpha': [0], 
#     'reg_lambda': [0, 1],
#     'max_bin': [64, 128, 255]
# }
In [23]:
# # Define LightGBM Classifier
# clf = LGBMClassifier(boosting_type='gbdt',
#                      objective='binary', 
#                      is_unbalance=True,
#                      random_state=random_state,
#                      n_jobs=n_jobs_clf, 
#                      silent=True, 
#                      importance_type='split')

# # Define K-fold cross validation splitter
# kfold = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)

# # Perform grid search
# model = GridSearchCV(clf, param_grid=param_grid, scoring='f1', n_jobs=n_jobs_cv, cv=kfold, refit=True, verbose=3)
# model.fit(X_train, Y_train, 
#           sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])

# print('\n All results:')
# print(model.cv_results_)
# print('\n Best estimator:')
# print(model.best_estimator_)
# print('\n Best hyperparameters:')
# print(model.best_params_)
In [24]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions, average='micro'))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [25]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-gridsearch.joblib')
# del model

Load LightGBM Classifier

In [26]:
model = load(f'models/{target_col}-{year}-model.joblib')
# model = load(f'models/{target_col}-{year}-gridsearch.joblib').best_estimator_
In [27]:
# Sanity check: Make predictions on Test set
predictions = model.predict(X_test)
print(accuracy_score(Y_test, predictions))
print(f1_score(Y_test, predictions))
print(confusion_matrix(Y_test, predictions))
print(classification_report(Y_test, predictions))
0.7359218308048342
0.6670988654781199
[[1833  669]
 [ 358 1029]]
              precision    recall  f1-score   support

           0       0.84      0.73      0.78      2502
           1       0.61      0.74      0.67      1387

   micro avg       0.74      0.74      0.74      3889
   macro avg       0.72      0.74      0.72      3889
weighted avg       0.75      0.74      0.74      3889

In [28]:
# Overfitting check: Make predictions on Train set
predictions = model.predict(X_train)
print(accuracy_score(Y_train, predictions))
print(f1_score(Y_train, predictions))
print(confusion_matrix(Y_train, predictions))
print(classification_report(Y_train, predictions))
0.7616449853126311
0.6967797485793009
[[36034 11498]
 [ 6110 20231]]
              precision    recall  f1-score   support

           0       0.86      0.76      0.80     47532
           1       0.64      0.77      0.70     26341

   micro avg       0.76      0.76      0.76     73873
   macro avg       0.75      0.76      0.75     73873
weighted avg       0.78      0.76      0.77     73873


Visualizations/Explainations

Note that these plot just explain how the XGBoost model works, not nessecarily how reality works. Since the XGBoost model is trained from observational data, it is not nessecarily a causal model, and so just because changing a factor makes the model's prediction of winning go up, does not always mean it will raise your actual chances.

In [29]:
# print the JS visualization code to the notebook
shap.initjs()

What makes a measure of feature importance good or bad?

  1. Consistency: Whenever we change a model such that it relies more on a feature, then the attributed importance for that feature should not decrease.
  2. Accuracy. The sum of all the feature importances should sum up to the total importance of the model. (For example if importance is measured by the R² value then the attribution to each feature should sum to the R² of the full model)

If consistency fails to hold, then we can’t compare the attributed feature importances between any two models, because then having a higher assigned attribution doesn’t mean the model actually relies more on that feature.

If accuracy fails to hold then we don’t know how the attributions of each feature combine to represent the output of the whole model. We can’t just normalize the attributions after the method is done since this might break the consistency of the method.

Using Tree SHAP for interpretting the model

In [30]:
explainer = shap.TreeExplainer(model)
# shap_values = explainer.shap_values(dataset)
shap_values = pickle.load(open(f'res/{target_col}-{year}-shapvals.obj', 'rb'))
In [31]:
# Visualize a single prediction
shap.force_plot(explainer.expected_value, shap_values[0,:], dataset_display.iloc[0,:])
Out[31]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue.

If we take many explanations such as the one shown above, rotate them 90 degrees, and then stack them horizontally, we can see explanations for an entire dataset (in the notebook this plot is interactive):

In [32]:
# Visualize many predictions
subsample = np.random.choice(len(dataset), 1000)  # Take random sub-sample
shap.force_plot(explainer.expected_value, shap_values[subsample,:], dataset_display.iloc[subsample,:])
Out[32]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Summary Plots

In [33]:
for col, sv in zip(dataset.columns, np.abs(shap_values).mean(0)):
    print(f"{col} - {sv}")
state - 0.36934720473541927
hh_religion - 0.09155813474124229
urban - 0.2123169276367779
women_anemic - 0.008612319573207317
total_children - 0.05607886222772018
obese_female - 0.10276747520660634
age - 0.10897319124813908
years_edu - 0.4759154970408336
wealth_index - 0.5822259860136847
hh_members - 0.053686631796088666
no_children_below5 - 0.15215823191656633
freq_tv - 0.10541708143321463
caste_don't know - 0.0014359959680787695
caste_general - 0.1368569408264368
caste_other backward class - 0.009570436373741685
caste_sc/st - 0.05798589187310115
In [34]:
shap.summary_plot(shap_values, dataset, plot_type="bar")

The above figure shows the global mean(|Tree SHAP|) method applied to our model.

The x-axis is essentially the average magnitude change in model output when a feature is “hidden” from the model (for this model the output has log-odds units). “Hidden” means integrating the variable out of the model. Since the impact of hiding a feature changes depending on what other features are also hidden, Shapley values are used to enforce consistency and accuracy.

However, since we now have individualized explanations for every person in our dataset, to get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low):

In [35]:
shap.summary_plot(shap_values, dataset_display)
  • Every person has one dot on each row.
  • The x position of the dot is the impact of that feature on the model’s prediction for the person.
  • The color of the dot represents the value of that feature for the customer. Categorical variables are colored grey.
  • Dots that don’t fit on the row pile up to show density (since our dataset is large).
  • Since the XGBoost model has a logistic loss the x-axis has units of log-odds (Tree SHAP explains the change in the margin output of the model).

How to use this: We can make analysis similar to the blog post for interpretting our models.


SHAP Dependence Plots

Next, to understand how a single feature effects the output of the model we can plot the SHAP value of that feature vs. the value of the feature for all the examples in a dataset. SHAP dependence plots show the effect of a single feature across the whole dataset. They plot a feature's value vs. the SHAP value of that feature across many samples.

SHAP dependence plots are similar to partial dependence plots, but account for the interaction effects present in the features, and are only defined in regions of the input space supported by data. The vertical dispersion of SHAP values at a single feature value is driven by interaction effects, and another feature is chosen for coloring to highlight possible interactions. One the benefits of SHAP dependence plots over traditional partial dependence plots is this ability to distigush between between models with and without interaction terms. In other words, SHAP dependence plots give an idea of the magnitude of the interaction terms through the vertical variance of the scatter plot at a given feature value.

Good example of using Dependency Plots: https://slundberg.github.io/shap/notebooks/League%20of%20Legends%20Win%20Prediction%20with%20XGBoost.html

Plots for 'age'

In [36]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('age', 'age'),
         ('age', 'urban'),
         ('age', 'caste_sc/st'),
         ('age', 'caste_general'),
         ('age', 'wealth_index'),
         ('age', 'years_edu'),
         ('age', 'no_children_below5'),
         ('age', 'total_children'),
         ('hh_religion', 'age'),
         ('state', 'age')]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: age
Feature: state, Interaction Feature: age

Plots for 'wealth_index'

In [37]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('wealth_index', 'wealth_index'),
         ('wealth_index', 'age'), 
         ('wealth_index', 'urban'),
         ('wealth_index', 'caste_sc/st'),
         ('wealth_index', 'caste_general'),
         ('wealth_index', 'years_edu'),
         ('wealth_index', 'no_children_below5'),
         ('wealth_index', 'total_children'),
         ('hh_religion', 'wealth_index'),
         ('state', 'wealth_index')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: wealth_index
Feature: state, Interaction Feature: wealth_index

Plots for 'years_edu'

In [38]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('years_edu', 'years_edu'),
         ('years_edu', 'age'), 
         ('years_edu', 'urban'),
         ('years_edu', 'caste_sc/st'),
         ('years_edu', 'caste_general'),
         ('years_edu', 'wealth_index'),
         ('years_edu', 'no_children_below5'),
         ('years_edu', 'total_children'),
         ('hh_religion', 'years_edu'),
         ('state', 'years_edu')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: years_edu
Feature: state, Interaction Feature: years_edu

Plots for 'caste_sc/st'

In [39]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_sc/st', 'caste_sc/st'),
         ('caste_sc/st', 'age'), 
         ('caste_sc/st', 'urban'),
         ('caste_sc/st', 'years_edu'),
         ('caste_sc/st', 'wealth_index'),
         ('caste_sc/st', 'no_children_below5'),
         ('caste_sc/st', 'total_children'),
         ('hh_religion', 'caste_sc/st'),
         ('state', 'caste_sc/st')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_sc/st
Feature: state, Interaction Feature: caste_sc/st

Plots for 'caste_general'

In [40]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_general', 'caste_general'),
         ('caste_general', 'age'), 
         ('caste_general', 'urban'),
         ('caste_general', 'years_edu'),
         ('caste_general', 'wealth_index'),
         ('caste_general', 'no_children_below5'),
         ('caste_general', 'total_children'),
         ('hh_religion', 'caste_general'),
         ('state', 'caste_general')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_general
Feature: state, Interaction Feature: caste_general

Visualizing Bar/Summary plots split by age bins

In [41]:
bins = [(21,25), (26,30), (31,35), (36,40), (41,45), (46,50)]

for low, high in bins:
    # Sample dataset by age range
    dataset_sample = dataset[(dataset.age > low) & (dataset.age <= high)]
    dataset_display_sample = dataset_display[(dataset.age > low) & (dataset.age <= high)]
    targets_sample = targets[(dataset.age > low) & (dataset.age <= high)]
    shap_values_sample = shap_values[(dataset.age > low) & (dataset.age <= high)]
    
    print("\nAge Range: {} - {} years".format(low, high))
    print("Sample size: {}\n".format(len(dataset_sample)))
    
    for col, sv in zip(dataset_sample.columns, np.abs(shap_values_sample).mean(0)):
        print(f"{col} - {sv}")
    
    # Summary plots
    shap.summary_plot(shap_values_sample, dataset_sample, plot_type="bar")
    shap.summary_plot(shap_values_sample, dataset_display_sample)
Age Range: 21 - 25 years
Sample size: 13487

state - 0.3715626599828414
hh_religion - 0.0859625776036463
urban - 0.21749103946755033
women_anemic - 0.009451313510487155
total_children - 0.08366387208409257
obese_female - 0.07224041765948605
age - 0.16578980115648223
years_edu - 0.5029625058799383
wealth_index - 0.5457597779427326
hh_members - 0.054795807080076094
no_children_below5 - 0.1770701703236233
freq_tv - 0.10066990772587125
caste_don't know - 0.0013806969326833915
caste_general - 0.11716466021404015
caste_other backward class - 0.009071794156521458
caste_sc/st - 0.04806874715945885
Age Range: 26 - 30 years
Sample size: 13101

state - 0.3712619459465034
hh_religion - 0.0920491883022542
urban - 0.21548485568355102
women_anemic - 0.00947838651878436
total_children - 0.05249654608834615
obese_female - 0.0924564262075922
age - 0.052356778889163444
years_edu - 0.48863789367082916
wealth_index - 0.5773068742871454
hh_members - 0.051027201913860114
no_children_below5 - 0.1794027539855777
freq_tv - 0.11011824522850011
caste_don't know - 0.001464001280603125
caste_general - 0.1334645336369998
caste_other backward class - 0.009516395592122653
caste_sc/st - 0.05800089902442302
Age Range: 31 - 35 years
Sample size: 12209

state - 0.3678917790550324
hh_religion - 0.08833730895721567
urban - 0.21064515276810789
women_anemic - 0.008430125421565477
total_children - 0.04499526144885559
obese_female - 0.10795878472835697
age - 0.12122372233145487
years_edu - 0.4638978613135957
wealth_index - 0.5865134990356844
hh_members - 0.05266272388964973
no_children_below5 - 0.16391130610095989
freq_tv - 0.11092513566082407
caste_don't know - 0.0014779602081806261
caste_general - 0.14443903568754463
caste_other backward class - 0.009273816203793046
caste_sc/st - 0.059368852885641296
Age Range: 36 - 40 years
Sample size: 10658

state - 0.3655054517795273
hh_religion - 0.09029897374272683
urban - 0.2078267861350017
women_anemic - 0.007549383948100648
total_children - 0.04686958815634176
obese_female - 0.11721240551621977
age - 0.10405578040956877
years_edu - 0.45150476386098926
wealth_index - 0.5909363530607962
hh_members - 0.05361984134792487
no_children_below5 - 0.13154077776120535
freq_tv - 0.10735859080345912
caste_don't know - 0.001399523293498242
caste_general - 0.14468962581669512
caste_other backward class - 0.010729836843497473
caste_sc/st - 0.06293205994970028
Age Range: 41 - 45 years
Sample size: 8613

state - 0.3677645411455765
hh_religion - 0.10118845861172411
urban - 0.21245633653335455
women_anemic - 0.007632714242114331
total_children - 0.043795957398847785
obese_female - 0.13014245645896563
age - 0.07580513918503683
years_edu - 0.4520806926376144
wealth_index - 0.6089550275080123
hh_members - 0.05522424255853601
no_children_below5 - 0.10979232134227883
freq_tv - 0.10089043150651653
caste_don't know - 0.0015609794401015874
caste_general - 0.15001989740844374
caste_other backward class - 0.009946430759123853
caste_sc/st - 0.06664787213157282
Age Range: 46 - 50 years
Sample size: 4290

state - 0.3550368168469482
hh_religion - 0.10679353961736636
urban - 0.21043329873496208
women_anemic - 0.008101837943367897
total_children - 0.04079285807480538
obese_female - 0.14170578386653426
age - 0.12451236573681664
years_edu - 0.46424886961081097
wealth_index - 0.6353010395466357
hh_members - 0.06146534322802115
no_children_below5 - 0.09738805162508182
freq_tv - 0.09202744521893268
caste_don't know - 0.0012748033768087566
caste_general - 0.1512182939998681
caste_other backward class - 0.009497386847711396
caste_sc/st - 0.06581630520054238

SHAP Interaction Values

SHAP interaction values are a generalization of SHAP values to higher order interactions.

The model returns a matrix for every prediction, where the main effects are on the diagonal and the interaction effects are off-diagonal. The main effects are similar to the SHAP values you would get for a linear model, and the interaction effects captures all the higher-order interactions are divide them up among the pairwise interaction terms.

Note that the sum of the entire interaction matrix is the difference between the model's current output and expected output, and so the interaction effects on the off-diagonal are split in half (since there are two of each). When plotting interaction effects the SHAP package automatically multiplies the off-diagonal values by two to get the full interaction effect.

In [42]:
# Sample from dataset based on sample weights
dataset_ss = dataset.sample(10000, weights=sample_weights, random_state=random_state)
print(dataset_ss.shape)
dataset_display_ss = dataset_display.loc[dataset_ss.index]
print(dataset_display_ss.shape)
(10000, 16)
(10000, 16)
In [43]:
# Compute SHAP interaction values (time consuming)
# shap_interaction_values = explainer.shap_interaction_values(dataset_ss)
shap_interaction_values = pickle.load(open(f'res/{target_col}-{year}-shapints.obj', 'rb'))
In [44]:
shap.summary_plot(shap_interaction_values, dataset_display_ss, max_display=15)

Heatmap of SHAP Interaction Values

In [45]:
tmp = np.abs(shap_interaction_values).sum(0)
for i in range(tmp.shape[0]):
    tmp[i,i] = 0
inds = np.argsort(-tmp.sum(0))[:50]
tmp2 = tmp[inds,:][:,inds]
pl.figure(figsize=(12,12))
pl.imshow(tmp2)
pl.yticks(range(tmp2.shape[0]), dataset_ss.columns[inds], rotation=50.4, horizontalalignment="right")
pl.xticks(range(tmp2.shape[0]), dataset_ss.columns[inds], rotation=50.4, horizontalalignment="left")
pl.gca().xaxis.tick_top()
pl.show()

SHAP Interaction Value Dependence Plots

Running a dependence plot on the SHAP interaction values a allows us to separately observe the main effects and the interaction effects.

Below we plot the main effects for age and some of the interaction effects for age. It is informative to compare the main effects plot of age with the earlier SHAP value plot for age. The main effects plot has no vertical dispersion because the interaction effects are all captured in the off-diagonal terms.

Good example of how to infer interesting stuff from interaction values: https://slundberg.github.io/shap/notebooks/NHANES%20I%20Survival%20Model.html

In [46]:
shap.dependence_plot(
    ("age", "age"), 
    shap_interaction_values, dataset_ss, display_features=dataset_display_ss
)

Now we plot the interaction effects involving age (and other features after that). These effects capture all of the vertical dispersion that was present in the original SHAP plot but is missing from the main effects plot above.

Plots for 'age'

In [47]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('age', 'age'),
         ('age', 'urban'),
         ('age', 'caste_sc/st'),
         ('age', 'caste_general'),
         ('age', 'wealth_index'),
         ('age', 'years_edu'),
         ('age', 'no_children_below5'),
         ('age', 'total_children')]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children

Plots for 'wealth_index'

In [48]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('wealth_index', 'wealth_index'),
         ('wealth_index', 'age'), 
         ('wealth_index', 'urban'),
         ('wealth_index', 'caste_sc/st'),
         ('wealth_index', 'caste_general'),
         ('wealth_index', 'years_edu'),
         ('wealth_index', 'no_children_below5'),
         ('wealth_index', 'total_children')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children

Plots for 'years_edu'

In [49]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('years_edu', 'years_edu'),
         ('years_edu', 'age'), 
         ('years_edu', 'urban'),
         ('years_edu', 'caste_sc/st'),
         ('years_edu', 'caste_general'),
         ('years_edu', 'wealth_index'),
         ('years_edu', 'no_children_below5'),
         ('years_edu', 'total_children')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children

Plots for 'caste_sc/st'

In [50]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_sc/st', 'caste_sc/st'),
         ('caste_sc/st', 'age'), 
         ('caste_sc/st', 'urban'),
         ('caste_sc/st', 'years_edu'),
         ('caste_sc/st', 'wealth_index'),
         ('caste_sc/st', 'no_children_below5'),
         ('caste_sc/st', 'total_children')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children

Plots for 'caste_general'

In [51]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_general', 'caste_general'),
         ('caste_general', 'age'), 
         ('caste_general', 'urban'),
         ('caste_general', 'years_edu'),
         ('caste_general', 'wealth_index'),
         ('caste_general', 'no_children_below5'),
         ('caste_general', 'total_children'),
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (col_name, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children